{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "epochs = 15"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Federated Learning - SMS spam prediction with a GRU model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial you will see how to use PySyft and PyTorch to train a 1-layer GRU model with Federated Learning.\n",
    "\n",
    "The data used for this project is the [SMS Spam Collection Data Set](https://archive.ics.uci.edu/ml/datasets/sms+spam+collection) available on the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/index.php). The dataset consists of about 5,500 SMS messages, of which around 13% are spam.\n",
    "\n",
    "The objective here is to simulate two remote machines (that we will call Bob and Anne), where each machine has a similar number of labeled data points (SMS messages labeled as spam or not).\n",
    "\n",
    "**Author**: André Macedo Farias. Github: [@andrelmfarias](https://github.com/andrelmfarias) | Twitter: [@andrelmfarias](https://twitter.com/andrelmfarias)\n",
    "\n",
    "*I also wrote a blog post about this tutorial and PySyft, feel free to check it out: [Private AI — Federated Learning with PySyft and PyTorch](https://towardsdatascience.com/private-ai-federated-learning-with-pysyft-and-pytorch-954a9e4a4d4e)*"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Useful imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "import warnings\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we are mostly interested in the usage of PySyft and Federated Learning, I will skip the text-preprocessing part of the project. If you are interested in how I preprocessed the raw dataset, you can take a look at the script [preprocess.py](https://github.com/OpenMined/PySyft/tree/dev/examples/data/SMS-spam/preprocess.py).\n",
    "\n",
    "Each data point of the `inputs.npy` dataset corresponds to an array of 30 tokens obtained from each message (padded on the left or truncated on the right).\n",
    "\n",
    "The `labels.npy` dataset has the following unique values: `1` for `spam` and `0` for `non-spam`."
   ]
  },
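  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the padding convention concrete, here is a minimal sketch of how a token list could be brought to a fixed length of 30. It is not taken from `preprocess.py`: the helper name `pad_or_truncate` and the padding id `0` are assumptions made only for illustration.\n",
    "\n",
    "```python\n",
    "SEQ_LEN = 30  # sequence length stated above\n",
    "\n",
    "\n",
    "def pad_or_truncate(tokens, seq_len=SEQ_LEN, pad_id=0):\n",
    "    # Hypothetical helper; pad_id=0 is an assumption, not read from preprocess.py\n",
    "    tokens = list(tokens)[:seq_len]  # truncate on the right\n",
    "    padding = [pad_id] * (seq_len - len(tokens))\n",
    "    return padding + tokens  # pad on the left\n",
    "\n",
    "\n",
    "short_message = pad_or_truncate([5, 8, 13])\n",
    "long_message = pad_or_truncate(range(1, 41))\n",
    "print(short_message)\n",
    "print(long_message)\n",
    "```\n",
    "\n",
    "A short message keeps its tokens at the end of the array, while a long message keeps only its first 30 tokens."
   ]
  },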
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = np.load('./data/inputs.npy')\n",
    "labels = np.load('./data/labels.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "VOCAB_SIZE = int(inputs.max()) + 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training model with Federated Learning"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training and model hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training params\n",
    "EPOCHS = epochs\n",
    "CLIP = 5 # gradient clipping - to avoid gradient explosion (frequent in RNNs)\n",
    "lr = 0.1\n",
    "BATCH_SIZE = 32\n",
    "\n",
    "# Model params\n",
    "EMBEDDING_DIM = 50\n",
    "HIDDEN_DIM = 10\n",
    "DROPOUT = 0.2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initiating virtual workers with PySyft"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this part we are going to split the dataset into training and test sets following an 80/20 ratio. Each of these datasets will then be split in two and sent to \"Bob's\" and \"Anne's\" machines in order to **simulate remote and private data**.\n",
    "\n",
    "Please note that in a real case, such datasets would already be on the remote machines and the preprocessing would be performed beforehand by the devices themselves."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import syft as sy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = torch.tensor(labels)\n",
    "inputs = torch.tensor(inputs)\n",
    "\n",
    "# splitting training and test data\n",
    "pct_test = 0.2\n",
    "\n",
    "train_labels = labels[:-int(len(labels)*pct_test)]\n",
    "train_inputs = inputs[:-int(len(labels)*pct_test)]\n",
    "\n",
    "test_labels = labels[-int(len(labels)*pct_test):]\n",
    "test_inputs = inputs[-int(len(labels)*pct_test):]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hook that extends the PyTorch library to enable all computations with pointers of tensors sent to other workers\n",
    "hook = sy.TorchHook(torch)\n",
    "\n",
    "# Creating 2 virtual workers\n",
    "bob = sy.VirtualWorker(hook, id=\"bob\")\n",
    "anne = sy.VirtualWorker(hook, id=\"anne\")\n",
    "\n",
    "# Threshold indexes for dataset split (one half for Bob, other half for Anne)\n",
    "train_idx = int(len(train_labels)/2)\n",
    "test_idx = int(len(test_labels)/2)\n",
    "\n",
    "# Sending toy datasets to virtual workers\n",
    "bob_train_dataset = sy.BaseDataset(train_inputs[:train_idx], train_labels[:train_idx]).send(bob)\n",
    "anne_train_dataset = sy.BaseDataset(train_inputs[train_idx:], train_labels[train_idx:]).send(anne)\n",
    "bob_test_dataset = sy.BaseDataset(test_inputs[:test_idx], test_labels[:test_idx]).send(bob)\n",
    "anne_test_dataset = sy.BaseDataset(test_inputs[test_idx:], test_labels[test_idx:]).send(anne)\n",
    "\n",
    "# Creating federated datasets, an extension of the PyTorch TensorDataset class\n",
    "federated_train_dataset = sy.FederatedDataset([bob_train_dataset, anne_train_dataset])\n",
    "federated_test_dataset = sy.FederatedDataset([bob_test_dataset, anne_test_dataset])\n",
    "\n",
    "# Creating federated dataloaders, an extension of the PyTorch DataLoader class\n",
    "federated_train_loader = sy.FederatedDataLoader(federated_train_dataset, shuffle=True, batch_size=BATCH_SIZE)\n",
    "federated_test_loader = sy.FederatedDataLoader(federated_test_dataset, shuffle=False, batch_size=BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating simple GRU (1-layer) model with sigmoid activation for classification task"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For educational purposes, we built a handcrafted GRU out of linear layers; you can check its architecture and code in [handcrafted_GRU.py](https://github.com/OpenMined/PySyft/blob/dev/examples/tutorials/advanced/Federated%20SMS%20Spam%20prediction/handcrafted_GRU.py).\n",
    "\n",
    "As the focus of this notebook is the usage of Federated Learning with PySyft, we do not show the construction of the model here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from handcrafted_GRU import GRU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initiating the model\n",
    "model = GRU(vocab_size=VOCAB_SIZE, hidden_dim=HIDDEN_DIM, embedding_dim=EMBEDDING_DIM, dropout=DROPOUT)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training and validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Defining loss and optimizer\n",
    "criterion = nn.BCELoss()\n",
    "optimizer = optim.SGD(model.parameters(), lr=lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For each epoch we compute the training and validation losses, as well as the [Area Under the ROC Curve](https://scikit-learn.org/stable/modules/model_evaluation.html#roc-metrics) score, because the target dataset is unbalanced (only 13% of labels are positive)."
   ]
  },
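  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The choice of AUC over accuracy can be illustrated with a toy example. The numbers below are made up and are not taken from the dataset: a model that gives every message the same score reaches 87% accuracy when 13% of labels are positive, yet its AUC is only 0.5. The helper `pairwise_auc` is a hypothetical plain-Python version of the pairwise definition of AUC, written only for this illustration; the notebook itself relies on `roc_auc_score`.\n",
    "\n",
    "```python\n",
    "def pairwise_auc(y_true, scores):\n",
    "    # AUC as the probability that a random positive is scored above a\n",
    "    # random negative, counting ties as one half\n",
    "    pos = [s for y, s in zip(y_true, scores) if y == 1]\n",
    "    neg = [s for y, s in zip(y_true, scores) if y == 0]\n",
    "    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)\n",
    "    return wins / (len(pos) * len(neg))\n",
    "\n",
    "\n",
    "# Toy labels: 13 spam messages out of 100\n",
    "y_true = [1] * 13 + [0] * 87\n",
    "\n",
    "# A useless model that gives every message the same low score\n",
    "constant_scores = [0.1] * 100\n",
    "accuracy = sum((s > 0.5) == bool(y) for y, s in zip(y_true, constant_scores)) / len(y_true)\n",
    "auc_constant = pairwise_auc(y_true, constant_scores)\n",
    "\n",
    "# A model that ranks every spam message above every non-spam message\n",
    "perfect_scores = [0.9] * 13 + [0.1] * 87\n",
    "auc_perfect = pairwise_auc(y_true, perfect_scores)\n",
    "\n",
    "print(accuracy, auc_constant, auc_perfect)\n",
    "```\n",
    "\n",
    "Accuracy rewards the constant model for the 87 non-spam messages, whereas AUC only looks at how well spam is ranked above non-spam."
   ]
  },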
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": []
   },
   "outputs": [],
   "source": [
    "for e in range(EPOCHS):\n",
    "\n",
    "    ######### Training ##########\n",
    "\n",
    "    losses = []\n",
    "    # Batch loop\n",
    "    for inputs, labels in federated_train_loader:\n",
    "        # Location of current batch\n",
    "        worker = inputs.location\n",
    "        # Initialize hidden state and send it to worker\n",
    "        h = torch.zeros((BATCH_SIZE, HIDDEN_DIM)).send(worker)\n",
    "        # Send model to current worker\n",
    "        model.send(worker)\n",
    "        # Setting accumulated gradients to zero before backward step\n",
    "        optimizer.zero_grad()\n",
    "        # Output from the model\n",
    "        output, _ = model(inputs, h)\n",
    "        # Calculate the loss and perform backprop\n",
    "        loss = criterion(output.squeeze(), labels.float())\n",
    "        loss.backward()\n",
    "        # Clipping the gradient to avoid explosion\n",
    "        nn.utils.clip_grad_norm_(model.parameters(), CLIP)\n",
    "        # Optimizer step: update the weights with the clipped gradients\n",
    "        optimizer.step()\n",
    "        # Get the model back to the local worker\n",
    "        model.get()\n",
    "        losses.append(loss.get())\n",
    "\n",
    "    ######## Evaluation ##########\n",
    "\n",
    "    # Model in evaluation mode\n",
    "    model.eval()\n",
    "\n",
    "    with torch.no_grad():\n",
    "        test_preds = []\n",
    "        test_labels_list = []\n",
    "        eval_losses = []\n",
    "\n",
    "        for inputs, labels in federated_test_loader:\n",
    "            # Location of current batch\n",
    "            worker = inputs.location\n",
    "            # Initialize hidden state and send it to worker\n",
    "            h = torch.zeros((BATCH_SIZE, HIDDEN_DIM)).send(worker)\n",
    "            # Send model to worker\n",
    "            model.send(worker)\n",
    "\n",
    "            output, _ = model(inputs, h)\n",
    "            loss = criterion(output.squeeze(), labels.float())\n",
    "            eval_losses.append(loss.get())\n",
    "            # Bring predictions and labels back to compute the AUC locally\n",
    "            preds = output.squeeze().get()\n",
    "            test_preds += list(preds.numpy())\n",
    "            test_labels_list += list(labels.get().numpy().astype(int))\n",
    "            # Get the model back to the local worker\n",
    "            model.get()\n",
    "\n",
    "        score = roc_auc_score(test_labels_list, test_preds)\n",
    "\n",
    "    print(\"Epoch {}/{}...  \\\n",
    "    AUC: {:.3%}...  \\\n",
    "    Training loss: {:.5f}...  \\\n",
    "    Validation loss: {:.5f}\".format(e+1, EPOCHS, score, sum(losses)/len(losses), sum(eval_losses)/len(eval_losses)))\n",
    "\n",
    "    # Back to training mode for the next epoch\n",
    "    model.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Well Done!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Et voilà! You have just trained a model for a real world application (SMS spam classifier) using Federated Learning!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see, with the PySyft library and its PyTorch extension you can perform operations on tensor pointers just as you would with the regular PyTorch API.\n",
    "\n",
    "Thanks to this, you were able to train a spam detector model without having any access to the remote and private data: for each batch you sent the model to the current remote worker and got it back to the local machine before sending it to the worker of the next batch.\n",
    "\n",
    "You can also notice that this federated training did not harm the performance of the model: both losses decreased at each epoch as expected, and the final AUC score on the test data was above 97.5%.\n",
    "\n",
    "There is however one limitation of this method: by getting the model back we can still have access to some private information.\n",
    "Let's say Bob had only one SMS on his machine. When we get the model back, we can just check which embeddings of the model changed and we will know which tokens (words) the SMS contained.\n",
    "\n",
    "In order to address this issue, there are two solutions: Differential Privacy and Secure Multi-Party Computation (SMPC).\n",
    "\n",
    "Differential Privacy would be used to make sure the model does not give access to private information.\n",
    "\n",
    "SMPC, which is one kind of Encrypted Computation, in turn allows you to send the model privately so that the remote workers which hold the data cannot see the weights you are using."
   ]
  },
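  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The leak described above can be reproduced locally with plain PyTorch, without PySyft. This is only a toy sketch with made-up sizes, token ids and a dummy loss: after one SGD step on a single message, the only embedding rows that changed are the ones belonging to the tokens of that message.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Toy embedding layer: 20 tokens, 4 dimensions (made-up sizes)\n",
    "embedding = nn.Embedding(20, 4)\n",
    "optimizer = optim.SGD(embedding.parameters(), lr=0.1)\n",
    "weights_before = embedding.weight.detach().clone()\n",
    "\n",
    "# The single private message, as token ids (made-up values)\n",
    "private_tokens = torch.tensor([[3, 7, 7, 12]])\n",
    "\n",
    "# One training step on that message with a dummy loss\n",
    "optimizer.zero_grad()\n",
    "embedding(private_tokens).sum().backward()\n",
    "optimizer.step()\n",
    "\n",
    "# Rows that changed reveal which tokens were in the message\n",
    "changed = (embedding.weight.detach() != weights_before).any(dim=1)\n",
    "leaked_tokens = changed.nonzero().flatten().tolist()\n",
    "print(leaked_tokens)\n",
    "```\n",
    "\n",
    "The recovered ids are 3, 7 and 12, the distinct tokens of the message. Their order and repetitions are not recovered, but the vocabulary of the message is exposed."
   ]
  },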
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Congratulations!!! - Time to Join the Community!\n",
    "\n",
    "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement toward privacy-preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!\n",
    "\n",
    "\n",
    "### Star PySyft on GitHub\n",
    "\n",
    "The easiest way to help our community is just by starring the repositories! This helps raise awareness of the cool tools we're building.\n",
    "\n",
    "- [Star PySyft](https://github.com/OpenMined/PySyft)\n",
    "\n",
    "### Pick our tutorials on GitHub!\n",
    "\n",
    "We made really nice tutorials to give a better understanding of what Federated and Privacy-Preserving Learning should look like and how we are building the bricks for this to happen.\n",
    "\n",
    "- [Check out the PySyft tutorials](https://github.com/OpenMined/PySyft/tree/master/examples/tutorials)\n",
    "\n",
    "\n",
    "### Join our Slack!\n",
    "\n",
    "The best way to keep up to date on the latest advancements is to join our community!\n",
    "\n",
    "- [Join slack.openmined.org](http://slack.openmined.org)\n",
    "\n",
    "### Join a Code Project!\n",
    "\n",
    "The best way to contribute to our community is to become a code contributor! If you want to start \"one off\" mini-projects, you can go to the PySyft GitHub Issues page and search for issues marked `Good First Issue`.\n",
    "\n",
    "- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "\n",
    "### Donate\n",
    "\n",
    "If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!\n",
    "\n",
    "- [Donate through OpenMined's Open Collective Page](https://opencollective.com/openmined)"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "hide_input": false,
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": true,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
